{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d3323204-1463-4df3-8c75-5e95b6d66ba1",
   "metadata": {},
   "source": [
    "# Creating a Llama-3 LoRA adapter with NeMo Framework"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29f3d632-44a0-4e6c-9229-b70bbcff1e99",
   "metadata": {},
   "source": [
    "This notebook showcases performing LoRA PEFT **Llama 3 8B** on [PubMedQA](https://pubmedqa.github.io/) using NeMo Framework. PubMedQA is a Question-Answering dataset for biomedical texts.\n",
    "\n",
    "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) which has all the required dependencies. **Instructions are available in the associated tutorial README to download the model and the container.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50de4d53",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install ipywidgets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b285d5a-d838-423b-9d6c-65add61f48ce",
   "metadata": {},
   "source": [
    "---\n",
    "## Before you begin\n",
    "Ensure that you have the `Meta Llama3 8B Instruct .nemo` model downloaded and the corresponding folder mounted to the container."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3057e525-7957-45c0-bedc-c347d4811081",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!ls /workspace/llama-3-8b-instruct-nemo_v1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deb6a910-a05e-4ae1-aac4-56e5092be2b4",
   "metadata": {
    "tags": []
   },
   "source": [
    "---\n",
    "##  Step-by-step instructions\n",
    "\n",
    "This notebook is structured into four steps:\n",
    "1. Prepare the dataset\n",
    "2. Run the PEFT finetuning script\n",
    "3. Inference with NeMo Framework\n",
    "4. Check the model accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea5bd31",
   "metadata": {},
   "source": [
    "### Step 1: Prepare the dataset\n",
    "\n",
    "Download the PubMedQA dataset and run the pre-processing script in the cloned directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "944b43c5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "# Download the dataset and prep. scripts\n",
    "git clone https://github.com/pubmedqa/pubmedqa.git\n",
    "\n",
    "# split it into train/val/test datasets\n",
    "cd pubmedqa/preprocess\n",
    "python split_dataset.py pqal"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8025b2d4",
   "metadata": {},
   "source": [
    "The following example shows what a single row looks inside of the PubMedQA train, validation and test splits.\n",
    "\n",
    "```json\n",
    "\"18251357\": {\n",
    "    \"QUESTION\": \"Does histologic chorioamnionitis correspond to clinical chorioamnionitis?\",\n",
    "    \"CONTEXTS\": [\n",
    "        \"To evaluate the degree to which histologic chorioamnionitis, a frequent finding in placentas submitted for histopathologic evaluation, correlates with clinical indicators of infection in the mother.\",\n",
    "        \"A retrospective review was performed on 52 cases with a histologic diagnosis of acute chorioamnionitis from 2,051 deliveries at University Hospital, Newark, from January 2003 to July 2003. Third-trimester placentas without histologic chorioamnionitis (n = 52) served as controls. Cases and controls were selected sequentially. Maternal medical records were reviewed for indicators of maternal infection.\",\n",
    "        \"Histologic chorioamnionitis was significantly associated with the usage of antibiotics (p = 0.0095) and a higher mean white blood cell count (p = 0.018). The presence of 1 or more clinical indicators was significantly associated with the presence of histologic chorioamnionitis (p = 0.019).\"\n",
    "    ],\n",
    "    \"reasoning_required_pred\": \"yes\",\n",
    "    \"reasoning_free_pred\": \"yes\",\n",
    "    \"final_decision\": \"yes\",\n",
    "    \"LONG_ANSWER\": \"Histologic chorioamnionitis is a reliable indicator of infection whether or not it is clinically apparent.\"\n",
    "},\n",
    "```\n",
    "\n",
    "Use the following code to convert the train, validation, and test PubMedQA data into the `JSONL` format that NeMo needs for PEFT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f69729",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def read_jsonl(fname):\n",
    "    obj = []\n",
    "    with open(fname, 'rt') as f:\n",
    "        st = f.readline()\n",
    "        while st:\n",
    "            obj.append(json.loads(st))\n",
    "            st = f.readline()\n",
    "    return obj\n",
    "\n",
    "def write_jsonl(fname, json_objs):\n",
    "    with open(fname, 'wt') as f:\n",
    "        for o in json_objs:\n",
    "            f.write(json.dumps(o)+\"\\n\")\n",
    "            \n",
    "def form_question(obj):\n",
    "    st = \"\"    \n",
    "    for i, label in enumerate(obj['LABELS']):\n",
    "        st += f\"{label}: {obj['CONTEXTS'][i]}\\n\"\n",
    "    st += f\"QUESTION: {obj['QUESTION']}\\n\"\n",
    "    st += f\" ### ANSWER (yes|no|maybe): \"\n",
    "    return st\n",
    "\n",
    "def convert_to_jsonl(data_path, output_path):\n",
    "    data = json.load(open(data_path, 'rt'))\n",
    "    json_objs = []\n",
    "    for k in data.keys():\n",
    "        obj = data[k]\n",
    "        prompt = form_question(obj)\n",
    "        completion = obj['final_decision']\n",
    "        json_objs.append({\"input\": prompt, \"output\": f\"<<< {completion} >>>\"})\n",
    "    write_jsonl(output_path, json_objs)\n",
    "    return json_objs\n",
    "\n",
    "\n",
    "test_json_objs = convert_to_jsonl(\"pubmedqa/data/test_set.json\", \"pubmedqa/data/pubmedqa_test.jsonl\")\n",
    "train_json_objs = convert_to_jsonl(\"pubmedqa/data/pqal_fold0/train_set.json\", \"pubmedqa/data/pubmedqa_train.jsonl\")\n",
    "dev_json_objs = convert_to_jsonl(\"pubmedqa/data/pqal_fold0/dev_set.json\", \"pubmedqa/data/pubmedqa_val.jsonl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62777542",
   "metadata": {},
   "source": [
    "> `Note:` In the output, we enforce the inclusion of “<<<” and “>>>“ markers which would allow verification of the LoRA tuned model during inference. This is  because the base model can produce “yes” / “no” responses based on zero-shot templates as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a3fc36",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# clear up cached mem-map file\n",
    "!rm pubmedqa/data/*idx*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ddd0f2a",
   "metadata": {},
   "source": [
    "After running the above script, you will see  `pubmedqa_train.jsonl`, `pubmedqa_val.jsonl`, and `pubmedqa_test.jsonl` files appear in the data directory.\n",
    "\n",
    "This is what an example will be formatted like after the script has converted the PubMedQA data into `JSONL` -\n",
    "\n",
    "```json\n",
    "{\"input\": \"QUESTION: Failed IUD insertions in community practice: an under-recognized problem?\\nCONTEXT: The data analysis was conducted to describe the rate of unsuccessful copper T380A intrauterine device (IUD) insertions among women using the IUD for emergency contraception (EC) at community family planning clinics in Utah.\\n ...  ### ANSWER (yes|no|maybe): \",\n",
    "\"output\": \"<<< yes >>>\"}\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0eb1d887",
   "metadata": {},
   "source": [
    "\n",
    "### Step 2: Run PEFT finetuning script for LoRA\n",
    "\n",
    "NeMo framework includes a high level python script for fine-tuning  [megatron_gpt_finetuning.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py) that can abstract away some of the lower level API calls. Once you have your model downloaded and the dataset ready, LoRA fine-tuning with NeMo is essentially just running this script!\n",
    "\n",
    "For this demonstration, this training run is capped by `max_steps`, and validation is carried out every `val_check_interval` steps. If the validation loss does not improve after a few checks, training is halted to avoid overfitting.\n",
    "\n",
    "> `NOTE:` In the block of code below, pass the paths to your train, test and validation data files as well as path to the .nemo model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c129f9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "# Set paths to the model, train, validation and test sets.\n",
    "MODEL=\"/workspace/llama-3-8b-instruct-nemo_v1.0/8b_instruct_nemo_bf16.nemo\"\n",
    "TRAIN_DS=\"[./pubmedqa/data/pubmedqa_train.jsonl]\"\n",
    "VALID_DS=\"[./pubmedqa/data/pubmedqa_val.jsonl]\"\n",
    "TEST_DS=\"[./pubmedqa/data/pubmedqa_test.jsonl]\"\n",
    "TEST_NAMES=\"[pubmedqa]\"\n",
    "\n",
    "SCHEME=\"lora\"\n",
    "TP_SIZE=1\n",
    "PP_SIZE=1\n",
    "\n",
    "OUTPUT_DIR=\"./results/Meta-Llama-3-8B-Instruct\"\n",
    "rm -r $OUTPUT_DIR\n",
    "\n",
    "torchrun --nproc_per_node=1 \\\n",
    "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \\\n",
    "    exp_manager.exp_dir=${OUTPUT_DIR} \\\n",
    "    exp_manager.explicit_log_dir=${OUTPUT_DIR} \\\n",
    "    trainer.devices=1 \\\n",
    "    trainer.num_nodes=1 \\\n",
    "    trainer.precision=bf16-mixed \\\n",
    "    trainer.val_check_interval=20 \\\n",
    "    trainer.max_steps=500 \\\n",
    "    model.megatron_amp_O2=True \\\n",
    "    ++model.mcore_gpt=True \\\n",
    "    model.tensor_model_parallel_size=${TP_SIZE} \\\n",
    "    model.pipeline_model_parallel_size=${PP_SIZE} \\\n",
    "    model.micro_batch_size=1 \\\n",
    "    model.global_batch_size=8 \\\n",
    "    model.restore_from_path=${MODEL} \\\n",
    "    model.data.train_ds.num_workers=0 \\\n",
    "    model.data.validation_ds.num_workers=0 \\\n",
    "    model.data.train_ds.file_names=${TRAIN_DS} \\\n",
    "    model.data.train_ds.concat_sampling_probabilities=[1.0] \\\n",
    "    model.data.validation_ds.file_names=${VALID_DS} \\\n",
    "    model.peft.peft_scheme=${SCHEME}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf4331fd-da30-4e29-8477-3085118e4a7b",
   "metadata": {},
   "source": [
    "This will create a LoRA adapter - a file named `megatron_gpt_peft_lora_tuning.nemo` in `./results/Meta-Llama-3-8B-Instruct/checkpoints/`. We'll use this later.\n",
    "\n",
    "To further configure the run above -\n",
    "\n",
    "* **A different PEFT technique**: The `peft.peft_scheme` parameter determines the technique being used. In this case, we did LoRA, but NeMo Framework supports other techniques as well - such as P-tuning, Adapters, and IA3. For more information, refer to the [PEFT support matrix](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/nlp/nemo_megatron/peft/landing_page.html). For example, for P-tuning, simply set \n",
    "\n",
    "```bash\n",
    "model.peft.peft_scheme=\"ptuning\" # instead of \"lora\"\n",
    "```\n",
    "\n",
    "* **Tuning Llama-3 70B**: You will need 8xA100 or 8xH100 GPUs. Provide the path to it's .nemo checkpoint (similar to the download and conversion steps earlier), and change the model parallelization settings for Llama-3 70B PEFT to distribute across the GPUs. It is also recommended to run the fine-tuning script from a terminal directly instead of Jupyter when using more than 1 GPU.\n",
    "```bash\n",
    "model.tensor_model_parallel_size=8\n",
    "model.pipeline_model_parallel_size=1\n",
    "```\n",
    "\n",
    "You can override many such configurations while running the script. A full set of possible configurations is located in [NeMo Framework Github](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53979a4d",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Step 3: Inference with NeMo Framework\n",
    "\n",
    "Running text generation within the framework is also possible with running a Python script. Note that is more for testing and validation, not a full-fledged  deployment solution like NVIDIA NIM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00d1e3f8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Check that the LORA model file exists\n",
    "!ls -l ./results/Meta-Llama-3-8B-Instruct/checkpoints"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3430a0b0-05a0-4179-8750-151d492bb9ae",
   "metadata": {},
   "source": [
    "In the code snippet below, the following configurations are worth noting - \n",
    "\n",
    "1. `model.restore_from_path` to the path for the Meta-Llama-3-8B-Instruct.nemo file.\n",
    "2. `model.peft.restore_from_path` to the path for the PEFT checkpoint that was created in the fine-tuning run in the last step.\n",
    "3. `model.test_ds.file_names` to the path of the pubmedqa_test.jsonl file\n",
    "\n",
    "If you have made any changes in model or experiment paths, please ensure they are configured correctly below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "568eb35d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "MODEL=\"./Meta-Llama-3-8B-Instruct.nemo\"\n",
    "TEST_DS=\"[./pubmedqa/data/pubmedqa_test.jsonl]\"\n",
    "TEST_NAMES=\"[pubmedqa]\"\n",
    "SCHEME=\"lora\"\n",
    "TP_SIZE=1\n",
    "PP_SIZE=1\n",
    "\n",
    "# This is where your LoRA checkpoint was saved\n",
    "PATH_TO_TRAINED_MODEL=\"./results/Meta-Llama-3-8B-Instruct/checkpoints/megatron_gpt_peft_lora_tuning.nemo\"\n",
    "\n",
    "# The generation run will save the generated outputs over the test dataset in a file prefixed like so\n",
    "OUTPUT_PREFIX=\"pubmedQA_result_\"\n",
    "\n",
    "python /opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \\\n",
    "    model.restore_from_path=${MODEL} \\\n",
    "    model.peft.restore_from_path=${PATH_TO_TRAINED_MODEL} \\\n",
    "    trainer.devices=1 \\\n",
    "    trainer.num_nodes=1 \\\n",
    "    model.data.test_ds.file_names=${TEST_DS} \\\n",
    "    model.data.test_ds.names=${TEST_NAMES} \\\n",
    "    model.data.test_ds.global_batch_size=1 \\\n",
    "    model.data.test_ds.micro_batch_size=1 \\\n",
    "    model.data.test_ds.tokens_to_generate=3 \\\n",
    "    model.tensor_model_parallel_size=${TP_SIZE} \\\n",
    "    model.pipeline_model_parallel_size=${PP_SIZE} \\\n",
    "    inference.greedy=True \\\n",
    "    model.data.test_ds.output_file_path_prefix=${OUTPUT_PREFIX} \\\n",
    "    model.data.test_ds.write_predictions_to_file=True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fe048f9",
   "metadata": {},
   "source": [
    "### Step 4: Check the model accuracy\n",
    "\n",
    "Now that the results are in, let's read the results and calculate the accuracy on the pubmedQA task. You can compare your accuracy results with the public leaderboard at https://pubmedqa.github.io/.\n",
    "\n",
    "Let's take a look at one of the predictions in the generated output file. The `pred` key indicates what was generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5c0fdc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!tail -n 1 pubmedQA_result__test_pubmedqa_inputs_preds_labels.jsonl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1c91df7",
   "metadata": {},
   "source": [
    "Note that the model produces output in the specified format, such as `<<< no >>>`.\n",
    "\n",
    "The following snippet loads the generated output and calculates accuracy in comparison to the test set using the `evaluation.py` script included in the PubMedQA repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "900f81c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "answers = []\n",
    "with open(\"pubmedQA_result__test_pubmedqa_inputs_preds_labels.jsonl\",'rt') as f:\n",
    "    st = f.readline()\n",
    "    while st:\n",
    "        answers.append(json.loads(st))\n",
    "        st = f.readline()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e1bbce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data_test = json.load(open(\"./pubmedqa/data/test_set.json\",'rt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a85926e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "results = {}\n",
    "sample_id = list(data_test.keys())\n",
    "\n",
    "for i, key in enumerate(sample_id):\n",
    "    answer = answers[i]['pred']\n",
    "    if 'yes' in answer:\n",
    "        results[key] = 'yes'\n",
    "    elif 'no' in answer:\n",
    "        results[key] = 'no'\n",
    "    elif 'maybe' in answer:\n",
    "        results[key] = 'maybe'\n",
    "    else:\n",
    "        print(\"Malformed answer: \", answer)\n",
    "        results[key] = 'maybe'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea1a217",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Dump results in a format that can be ingested by PubMedQA evaluation file\n",
    "FILENAME=\"pubmedqa-llama-3-8b-lora.json\"\n",
    "with(open(FILENAME, \"w\")) as f:\n",
    "    json.dump(results, f)\n",
    "\n",
    "# Evaluation\n",
    "!cp $FILENAME ./pubmedqa/\n",
    "!cd ./pubmedqa/ && python evaluation.py $FILENAME"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9909283e-e1f8-450e-a730-403e22f621ad",
   "metadata": {},
   "source": [
    "For the Llama-3-8B-Instruct model, you should see accuracy comparable to the below:\n",
    "```\n",
    "Accuracy 0.792000\n",
    "Macro-F1 0.594778\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
